"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""

import requests
import hashlib
import os
import torch
import torch.nn as nn
from torchvision import models
from collections import namedtuple
from tqdm import tqdm
from functools import partial
from image_synthesis.modeling.modules.vqgan_loss.moco_net import resnet50
from image_synthesis.modeling.codecs.image_codec.vision_transformer import VisionTransformer 

# from taming.util import get_ckpt_path

class ViTPER(nn.Module):
    """Perceptual distance built on frozen ViT-B/16 intermediate features.

    Unlike the reference LPIPS, there are no learned per-layer linear heads:
    the distance is the unweighted sum, over a few transformer depths, of the
    averaged squared feature difference between ``input`` and ``target``.
    All backbone parameters are frozen in ``__init__``.

    NOTE(review): no backbone weights are loaded in this class; presumably
    the caller loads MoCo ViT-B weights into ``moco_vitb`` — TODO confirm.
    """

    # Values passed as ``index_list`` to the backbone; presumably the indices
    # of the transformer blocks whose outputs are compared (depth is 12).
    FEATURE_LAYERS = (3, 6, 9, 12)

    def __init__(self, use_dropout=True):
        """Build the scaling layer and the frozen ViT-B/16 backbone.

        Args:
            use_dropout: currently unused; kept so the constructor signature
                does not change.
        """
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.moco_vitb = VisionTransformer(
            patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6))

        # Drop whatever module sits at ``fc`` (presumably a classification head).
        self.moco_vitb.fc = torch.nn.Identity()

        # The backbone is used as a fixed feature extractor.
        for param in self.moco_vitb.parameters():
            param.requires_grad = False

        self.trainable = False

    def train(self, mode=True):
        """Enter training mode only when ``self.trainable`` is set.

        This stops a parent module's ``.train()`` call from switching this
        module (and its submodules) into training mode while it is frozen.
        """
        if self.trainable and mode:
            return super().train(True)
        else:
            return super().train(False)

    def load_from_pretrained(self, name="vgg_lpips"):
        """Load a checkpoint resolved by ``get_ckpt_path(name)`` (non-strict).

        NOTE(review): ``get_ckpt_path`` appears not to be imported in this
        module (the ``taming.util`` import is commented out), so this likely
        raises NameError unless it is provided elsewhere. Because loading is
        ``strict=False``, keys that do not match this ViT model are skipped —
        confirm that a "vgg_lpips" checkpoint loads anything useful here.
        """
        ckpt = get_ckpt_path(name)
        self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
        print("loaded pretrained LPIPS loss from {}".format(ckpt))

    @classmethod
    def from_pretrained(cls, name="vgg_lpips"):
        """Construct a model and load the ``vgg_lpips`` checkpoint non-strictly.

        Raises:
            NotImplementedError: if ``name`` is anything other than "vgg_lpips".

        NOTE(review): depends on ``get_ckpt_path``, which appears to be
        undefined in this module (see ``load_from_pretrained``).
        """
        if name != "vgg_lpips":
            raise NotImplementedError
        model = cls()
        ckpt = get_ckpt_path(name)
        model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
        return model

    def forward(self, input, target):
        """Return the summed per-layer feature distance between two images.

        Both images go through ``ScalingLayer`` and then the frozen backbone.
        For each requested layer, the squared difference of the features is
        reduced with ``spatial_average(..., keepdim=True)``. The per-layer
        results are then added together.

        NOTE(review): the output shape is whatever ``spatial_average``
        produces for the backbone's outputs (presumably (B, 1, 1, 1)).
        """
        layers = list(self.FEATURE_LAYERS)
        scaled_input = self.scaling_layer(input)
        scaled_target = self.scaling_layer(target)
        feats_input = self.moco_vitb(scaled_input, index_list=layers)
        feats_target = self.moco_vitb(scaled_target, index_list=layers)

        per_layer = [
            spatial_average((f_in - f_tgt) ** 2, keepdim=True)
            for f_in, f_tgt in zip(feats_input, feats_target)
        ]

        # Out-of-place accumulation keeps per_layer[0] untouched.
        total = per_layer[0]
        for term in per_layer[1:]:
            total = total + term
        return total


class ScalingLayer(nn.Module):
    """Optionally downsize an NCHW RGB batch, then normalise each channel.

    Each channel is shifted and scaled by the constants from the reference
    LPIPS implementation. The constants are registered as buffers, so they
    move with the module's device and are stored in its ``state_dict``.
    """

    def __init__(self, *, resize_to=256):
        """
        Args:
            resize_to: inputs whose height exceeds this value are resized to
                ``(resize_to, resize_to)``. Defaults to 256, the previously
                hard-coded value.
        """
        super(ScalingLayer, self).__init__()
        self.resize_to = resize_to
        self.register_buffer('shift', torch.tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer('scale', torch.tensor([.458, .448, .450])[None, :, None, None])

    def forward(self, inp):
        """Return ``(inp - shift) / scale``, resizing first if ``inp`` is too tall.

        Args:
            inp: 4-D tensor; the buffers broadcast as (1, 3, 1, 1), so dim 1 must
                have size 3 (or 1).
        """
        # Only the height (dim 2) is checked, and the target size is square, so
        # non-square inputs are stretched.
        # NOTE(review): confirm that ignoring width is intended.
        if inp.shape[2] > self.resize_to:
            # interpolate's default mode is 'nearest'.
            inp = nn.functional.interpolate(inp, size=(self.resize_to, self.resize_to))
        return (inp - self.shift) / self.scale



def normalize_tensor(x, eps=1e-10):
    """Divide ``x`` by its L2 norm taken along dimension 1.

    Args:
        x: tensor with at least two dimensions; dimension 1 is reduced.
        eps: added to the norm before dividing, so all-zero vectors come out
            as zeros instead of NaN.

    Returns:
        A tensor with the same shape as ``x``, normalised along dimension 1.
    """
    channel_norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()
    denominator = channel_norm + eps
    return x / denominator


def spatial_average(x, keepdim=True):
    """Average a 3-D tensor over every dimension except the first.

    A singleton axis is inserted at position 3 before reducing axes 1-3. For
    an input of shape (B, N, C) the result is (B, 1, 1, 1) when ``keepdim``
    is true, or (B,) otherwise.

    NOTE(review): for a 4-D input (B, C, H, W) the last axis is NOT reduced
    (the result is (B, 1, 1, 1, W)); confirm callers only pass 3-D tensors.
    """
    expanded = torch.unsqueeze(x, 3)
    return torch.mean(expanded, dim=[1, 2, 3], keepdim=keepdim)

